/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for GEMM performing a reduction over K partitions in
   parallel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm.h"

#include "cutlass/gemm/kernel/default_gemm_splitk_parallel.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

////////////////////////////////////////////////////////////////////////////////

/*!
  Device-level GEMM operator that partitions the K dimension, computes the
  partial product of every partition in parallel, and then combines the partial
  products with a separate reduction kernel.

  Execution consists of two kernel launches issued on the same stream:
    1. GemmKernel stores one M-by-N slice of ElementAccumulator per K partition
       into a caller-provided workspace (see get_workspace_size()); the slices
       are converted with ConvertScaledOp.
    2. ReductionKernel combines the slices with ReductionOp and produces D via
       EpilogueOutputOp, reading C as the source operand.

  Typical usage: build an Arguments object, query get_workspace_size(),
  allocate the workspace, then call initialize() followed by run() (or the
  combined operator()).
*/
template <
        /// Element type for A matrix operand
        typename ElementA_,
        /// Layout type for A matrix operand
        typename LayoutA_,
        /// Element type for B matrix operand
        typename ElementB_,
        /// Layout type for B matrix operand
        typename LayoutB_,
        /// Element type for C and D matrix operands
        typename ElementC_,
        /// Layout type for C and D matrix operands
        typename LayoutC_,
        /// Element type for internal accumulation
        typename ElementAccumulator_ = ElementC_,
        /// Operator class tag
        typename OperatorClass_ = arch::OpClassSimt,
        /// Tag indicating architecture to tune for
        typename ArchTag_ = arch::Sm70,
        /// Threadblock-level tile size (concept: GemmShape)
        typename ThreadblockShape_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::ThreadblockShape,
        /// Warp-level tile size (concept: GemmShape)
        typename WarpShape_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::WarpShape,
        /// Instruction-level tile size (concept: GemmShape)
        typename InstructionShape_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::InstructionShape,
        /// Epilogue output operator (applied by the reduction kernel)
        typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::EpilogueOutputOp,
        /// Conversion operator used by the GEMM kernel when storing partial
        /// accumulators to the workspace
        typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<
                ElementAccumulator_,
                DefaultGemmConfiguration<
                        OperatorClass_, ArchTag_, ElementA_, ElementB_,
                        ElementAccumulator_,
                        ElementAccumulator_>::EpilogueOutputOp::kCount,
                ElementAccumulator_>,
        /// Reduction operator
        typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd<
                ElementAccumulator_,
                typename EpilogueOutputOp_::ElementAccumulator,
                EpilogueOutputOp_::kCount>,
        /// Threadblock-level swizzling operator
        typename ThreadblockSwizzle_ =
                threadblock::GemmSplitKHorizontalThreadblockSwizzle,
        /// Number of stages used in the pipelined mainloop
        int Stages = DefaultGemmConfiguration<OperatorClass_, ArchTag_,
                                              ElementA_, ElementB_, ElementC_,
                                              ElementAccumulator_>::kStages,
        /// Access granularity of A matrix in units of elements
        int AlignmentA = DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::kAlignmentA,
        /// Access granularity of B matrix in units of elements
        int AlignmentB = DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::kAlignmentB,
        /// Operation performed by GEMM
        typename Operator_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::Operator>
class GemmSplitKParallel {
public:
    using ElementA = ElementA_;
    using LayoutA = LayoutA_;
    using ElementB = ElementB_;
    using LayoutB = LayoutB_;
    using ElementC = ElementC_;
    using LayoutC = LayoutC_;
    using ElementAccumulator = ElementAccumulator_;
    using OperatorClass = OperatorClass_;
    using ArchTag = ArchTag_;
    using ThreadblockShape = ThreadblockShape_;
    using WarpShape = WarpShape_;
    using InstructionShape = InstructionShape_;
    using ConvertScaledOp = ConvertScaledOp_;
    using EpilogueOutputOp = EpilogueOutputOp_;
    using ReductionOp = ReductionOp_;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    using Operator = Operator_;
    static int const kStages = Stages;
    static int const kAlignmentA = AlignmentA;
    static int const kAlignmentB = AlignmentB;
    static int const kAlignmentC = EpilogueOutputOp::kCount;
    static ComplexTransform const kTransformA = ComplexTransform::kNone;
    static ComplexTransform const kTransformB = ComplexTransform::kNone;

    /// GEMM kernel. Note that its output element type is ElementAccumulator
    /// (not ElementC): it produces the partial products held in the workspace.
    using GemmKernel = typename kernel::DefaultGemmSplitKParallel<
            ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
            ElementAccumulator, LayoutC, ElementAccumulator, OperatorClass,
            ArchTag, ThreadblockShape, WarpShape, InstructionShape,
            ConvertScaledOp, ThreadblockSwizzle, kStages, Operator>::GemmKernel;

    /// Reduction kernel combining the per-partition slices into the output
    using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
            cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
            EpilogueOutputOp, ReductionOp>;

    /// Argument structure
    struct Arguments {
        //
        // Data members
        //

        /// Extent of the GEMM problem (M, N, K)
        GemmCoord problem_size;
        /// Reference to the A operand
        TensorRef<ElementA const, LayoutA> ref_A;
        /// Reference to the B operand
        TensorRef<ElementB const, LayoutB> ref_B;
        /// Reference to the source matrix C
        TensorRef<ElementC const, LayoutC> ref_C;
        /// Reference to the destination matrix D
        TensorRef<ElementC, LayoutC> ref_D;
        /// Parameters of the epilogue output operator
        typename EpilogueOutputOp::Params epilogue;
        /// Requested number of partitions along K
        int split_k_slices;
        /// Parameters of the workspace conversion operator
        typename ConvertScaledOp::Params convert;
        /// Parameters of the reduction operator
        typename ReductionOp::Params reduction;

        //
        // Methods
        //

        /// Default ctor: empty problem with a single K partition, so that no
        /// scalar member is left with an indeterminate value.
        CUTLASS_HOST_DEVICE
        Arguments() : problem_size(0, 0, 0), split_k_slices(1) {}

        /// Constructs an Arguments structure
        CUTLASS_HOST_DEVICE
        Arguments(GemmCoord problem_size_,
                  TensorRef<ElementA const, LayoutA> ref_A_,
                  TensorRef<ElementB const, LayoutB> ref_B_,
                  TensorRef<ElementC const, LayoutC> ref_C_,
                  TensorRef<ElementC, LayoutC> ref_D_,
                  typename EpilogueOutputOp::Params epilogue_ =
                          typename EpilogueOutputOp::Params(),
                  int split_k_slices = 1,
                  typename ConvertScaledOp::Params convert_ =
                          typename ConvertScaledOp::Params(),
                  typename ReductionOp::Params reduction_ =
                          typename ReductionOp::Params())
                : problem_size(problem_size_),
                  ref_A(ref_A_),
                  ref_B(ref_B_),
                  ref_C(ref_C_),
                  ref_D(ref_D_),
                  epilogue(epilogue_),
                  split_k_slices(split_k_slices),
                  convert(convert_),
                  reduction(reduction_) {}
    };

private:
    /// Kernel parameters object
    typename GemmKernel::Params gemm_params_;

    /// Reduction kernel parameters object
    typename ReductionKernel::Params reduction_params_;

public:
    /// Constructs the GEMM.
    GemmSplitKParallel() {}

    /// Determines whether the GEMM can execute the given problem.
    /// Currently performs no validation and always reports success.
    static Status can_implement(Arguments const& args) {
        // TODO

        return Status::kSuccess;
    }

    /// Gets the workspace size in bytes: one M-by-N slice of
    /// ElementAccumulator for every K partition of the tiled grid.
    static size_t get_workspace_size(Arguments const& args) {
        // Determine grid shape
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord grid_shape =
                threadblock_swizzle.get_tiled_shape(
                        args.problem_size,
                        {ThreadblockShape::kM, ThreadblockShape::kN,
                         ThreadblockShape::kK},
                        args.split_k_slices);

        // Widen before multiplying so that large problems do not overflow int.
        size_t const slice_elements = size_t(args.problem_size.m()) *
                                      size_t(args.problem_size.n());

        return sizeof(ElementAccumulator_) * slice_elements * grid_shape.k();
    }

    /// Initializes GEMM state from arguments.
    ///
    /// \param args       problem description and operand references
    /// \param workspace  buffer of at least get_workspace_size(args) bytes
    /// \return Status::kErrorWorkspaceNull if \p workspace is null, otherwise
    ///         Status::kSuccess
    Status initialize(Arguments const& args, void* workspace) {
        // The partial products always go through the workspace, so it is
        // mandatory for this operator.
        if (!workspace) {
            return Status::kErrorWorkspaceNull;
        }

        // Determine grid shape
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord grid_shape =
                threadblock_swizzle.get_tiled_shape(
                        args.problem_size,
                        {ThreadblockShape::kM, ThreadblockShape::kN,
                         ThreadblockShape::kK},
                        args.split_k_slices);

        // Define a reference to the workspace - this is an aligned region in
        // device memory, viewed as row-major slices with leading dimension N.
        TensorRef<ElementAccumulator_, layout::RowMajor> ref_workspace(
                static_cast<ElementAccumulator_*>(workspace),
                args.problem_size.n());

        // Distance, in elements, between the slices of consecutive partitions.
        int64_t partition_stride =
                int64_t(args.problem_size.m()) * int64_t(args.problem_size.n());

        // Initialize the Params structure
        gemm_params_ = typename GemmKernel::Params{args.problem_size,
                                                   grid_shape,
                                                   args.ref_A.non_const_ref(),
                                                   args.ref_B.non_const_ref(),
                                                   ref_workspace,
                                                   args.convert,
                                                   partition_stride};

        reduction_params_ = typename ReductionKernel::Params(
                args.problem_size.mn(), grid_shape.k(), partition_stride,
                ref_workspace, args.ref_D, args.ref_C.non_const_ref(),
                args.epilogue);

        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments: only the operand and
    /// workspace base pointers are replaced; problem size, layouts, grid shape
    /// and operator parameters remain those of the last initialize() call.
    ///
    /// \return Status::kErrorWorkspaceNull if \p workspace is null, otherwise
    ///         Status::kSuccess
    Status update(Arguments const& args, void* workspace = nullptr) {
        if (!workspace) {
            return Status::kErrorWorkspaceNull;
        }

        // The workspace is untyped; give it the element type used by
        // initialize() before handing it to the tensor references.
        ElementAccumulator_* workspace_ptr =
                static_cast<ElementAccumulator_*>(workspace);

        // The kernel parameters hold non-const references (initialize() stores
        // non_const_ref()), so strip constness the same way here.
        gemm_params_.ref_A.reset(args.ref_A.non_const_ref().data());
        gemm_params_.ref_B.reset(args.ref_B.non_const_ref().data());
        gemm_params_.ref_D.reset(workspace_ptr);

        // The reduction kernel reads the same workspace the GEMM kernel
        // writes, so both must be repointed together.
        reduction_params_.workspace.reset(workspace_ptr);
        reduction_params_.destination.reset(args.ref_D.data());
        reduction_params_.source.reset(args.ref_C.non_const_ref().data());

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state: launches the GEMM kernel and
    /// then the reduction kernel on \p stream. Only launch errors are
    /// reported; the stream is not synchronized.
    Status run(cudaStream_t stream = nullptr) {
        //
        // Launch GEMM kernel
        //

        ThreadblockSwizzle threadblock_swizzle;

        dim3 grid = threadblock_swizzle.get_grid_shape(
                gemm_params_.grid_tiled_shape);
        dim3 block(GemmKernel::kThreadCount, 1, 1);

        cudaError_t result;

        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        // Dynamic shared memory of 48 KiB or more requires an explicit opt-in
        // on the kernel function.
        if (smem_size >= (48 << 10)) {
            result = cudaFuncSetAttribute(
                    Kernel<GemmKernel>,
                    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess) {
                return Status::kErrorInternal;
            }

            result = cudaFuncSetAttribute(
                    Kernel<GemmKernel>,
                    cudaFuncAttributePreferredSharedMemoryCarveout, 100);

            if (result != cudaSuccess) {
                return Status::kErrorInternal;
            }
        }

        Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);

        result = cudaGetLastError();

        if (result != cudaSuccess) {
            return Status::kErrorInternal;
        }

        //
        // Launch reduction kernel
        //

        block = ReductionKernel::block_shape();
        grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn());

        Kernel<ReductionKernel><<<grid, block, 0, stream>>>(reduction_params_);

        result = cudaGetLastError();

        if (result != cudaSuccess) {
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr) { return run(stream); }

    /// Initializes from \p args and \p workspace, then runs the kernels if
    /// initialization succeeded.
    Status operator()(Arguments const& args, void* workspace = nullptr,
                      cudaStream_t stream = nullptr) {
        Status status = initialize(args, workspace);

        if (status == Status::kSuccess) {
            status = run(stream);
        }

        return status;
    }
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output.
///
/// A column-major M-by-N matrix occupies memory exactly like a row-major
/// N-by-M matrix, so this specialization forwards to the row-major operator
/// with the problem transposed: the A and B operands exchange roles, their
/// layouts are transposed, and the M and N extents are swapped.
template <
        /// Element type for A matrix operand
        typename ElementA_,
        /// Layout type for A matrix operand
        typename LayoutA_,
        /// Element type for B matrix operand
        typename ElementB_,
        /// Layout type for B matrix operand
        typename LayoutB_,
        /// Element type for C and D matrix operands
        typename ElementC_,
        /// Element type for internal accumulation
        typename ElementAccumulator_,
        /// Operator class tag
        typename OperatorClass_,
        /// Tag indicating architecture to tune for
        typename ArchTag_,
        /// Threadblock-level tile size (concept: GemmShape)
        typename ThreadblockShape_,
        /// Warp-level tile size (concept: GemmShape)
        typename WarpShape_,
        /// Instruction-level tile size (concept: GemmShape)
        typename InstructionShape_,
        /// Epilogue output operator
        typename EpilogueOutputOp_,
        /// Conversion operator used when storing partial accumulators
        typename ConvertScaledOp_,
        /// Reduction operator
        typename ReductionOp_,
        /// Threadblock-level swizzling operator
        typename ThreadblockSwizzle_,
        /// Number of stages used in the pipelined mainloop, followed by the
        /// access granularity of the A and B matrices in units of elements
        int Stages, int kAlignmentA, int kAlignmentB,
        /// Operation performed by GEMM
        typename Operator_>
class GemmSplitKParallel<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
                         layout::ColumnMajor, ElementAccumulator_,
                         OperatorClass_, ArchTag_, ThreadblockShape_,
                         WarpShape_, InstructionShape_, EpilogueOutputOp_,
                         ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_,
                         Stages, kAlignmentA, kAlignmentB, Operator_> {
public:
    using ElementA = ElementA_;
    using LayoutA = LayoutA_;
    using ElementB = ElementB_;
    using LayoutB = LayoutB_;
    using ElementC = ElementC_;
    using LayoutC = layout::ColumnMajor;
    using ElementAccumulator = ElementAccumulator_;
    using OperatorClass = OperatorClass_;
    using ArchTag = ArchTag_;
    using ThreadblockShape = ThreadblockShape_;
    using WarpShape = WarpShape_;
    using InstructionShape = InstructionShape_;
    using ConvertScaledOp = ConvertScaledOp_;
    using EpilogueOutputOp = EpilogueOutputOp_;
    using ReductionOp = ReductionOp_;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    using Operator = Operator_;
    static int const kStages = Stages;

    /// Row-major operator solving the transposed problem. Because B takes the
    /// place of A (and vice versa), the access alignments must be exchanged
    /// as well: the underlying "A" alignment is kAlignmentB and the underlying
    /// "B" alignment is kAlignmentA.
    using UnderlyingOperator = GemmSplitKParallel<
            ElementB, typename layout::LayoutTranspose<LayoutB>::type, ElementA,
            typename layout::LayoutTranspose<LayoutA>::type, ElementC,
            layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag,
            ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
            ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages,
            kAlignmentB, kAlignmentA, Operator>;

    using UnderlyingArguments = typename UnderlyingOperator::Arguments;
    using GemmKernel = typename UnderlyingOperator::GemmKernel;
    using ReductionKernel = typename UnderlyingOperator::ReductionKernel;

    /// Argument structure
    struct Arguments {
        //
        // Data members
        //

        /// Extent of the GEMM problem (M, N, K)
        GemmCoord problem_size;
        /// Reference to the A operand
        TensorRef<ElementA const, LayoutA> ref_A;
        /// Reference to the B operand
        TensorRef<ElementB const, LayoutB> ref_B;
        /// Reference to the source matrix C
        TensorRef<ElementC const, LayoutC> ref_C;
        /// Reference to the destination matrix D
        TensorRef<ElementC, LayoutC> ref_D;
        /// Parameters of the epilogue output operator
        typename EpilogueOutputOp::Params epilogue;
        /// Requested number of partitions along K
        int split_k_slices;
        /// Parameters of the workspace conversion operator
        typename ConvertScaledOp::Params convert;
        /// Parameters of the reduction operator
        typename ReductionOp::Params reduction;

        //
        // Methods
        //

        /// Default ctor: empty problem with a single K partition, so that no
        /// scalar member is left with an indeterminate value.
        CUTLASS_HOST_DEVICE
        Arguments() : problem_size(0, 0, 0), split_k_slices(1) {}

        /// Constructs an Arguments structure
        CUTLASS_HOST_DEVICE
        Arguments(GemmCoord problem_size_,
                  TensorRef<ElementA const, LayoutA> ref_A_,
                  TensorRef<ElementB const, LayoutB> ref_B_,
                  TensorRef<ElementC const, LayoutC> ref_C_,
                  TensorRef<ElementC, LayoutC> ref_D_,
                  typename EpilogueOutputOp::Params epilogue_ =
                          typename EpilogueOutputOp::Params(),
                  int split_k_slices = 1,
                  typename ConvertScaledOp::Params convert_ =
                          typename ConvertScaledOp::Params(),
                  typename ReductionOp::Params reduction_ =
                          typename ReductionOp::Params())
                : problem_size(problem_size_),
                  ref_A(ref_A_),
                  ref_B(ref_B_),
                  ref_C(ref_C_),
                  ref_D(ref_D_),
                  epilogue(epilogue_),
                  split_k_slices(split_k_slices),
                  convert(convert_),
                  reduction(reduction_) {}
    };

private:
    /// Row-major operator that performs the actual work
    UnderlyingOperator underlying_operator_;

public:
    /// Constructs the GEMM.
    GemmSplitKParallel() {}

    /// Helper to construct a transposed equivalent for the underlying GEMM
    /// operator: M and N are swapped, B is passed as the first operand and A
    /// as the second, and every reference keeps its pointer and leading
    /// stride.
    static UnderlyingArguments to_underlying_arguments(Arguments const& args) {
        return UnderlyingArguments(
                {args.problem_size.n(), args.problem_size.m(),
                 args.problem_size.k()},
                {args.ref_B.data(), args.ref_B.stride(0)},
                {args.ref_A.data(), args.ref_A.stride(0)},
                {args.ref_C.data(), args.ref_C.stride(0)},
                {args.ref_D.data(), args.ref_D.stride(0)}, args.epilogue,
                args.split_k_slices, args.convert, args.reduction);
    }

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const& args) {
        return UnderlyingOperator::can_implement(to_underlying_arguments(args));
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const& args) {
        return UnderlyingOperator::get_workspace_size(
                to_underlying_arguments(args));
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const& args, void* workspace) {
        return underlying_operator_.initialize(to_underlying_arguments(args),
                                               workspace);
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const& args, void* workspace = nullptr) {
        return underlying_operator_.update(to_underlying_arguments(args),
                                           workspace);
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr) {
        return underlying_operator_.run(stream);
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr) { return run(stream); }

    /// Initializes from \p args and \p workspace, then runs the kernels if
    /// initialization succeeded.
    Status operator()(Arguments const& args, void* workspace = nullptr,
                      cudaStream_t stream = nullptr) {
        Status status = initialize(args, workspace);

        if (status == Status::kSuccess) {
            status = run(stream);
        }

        return status;
    }
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace device
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
